# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

# Fixed preamble of the generated header: license banner, include guard,
# includes, the Lmul enum with its Widen() helper, scalar widening traits and
# the SameTypeBinaryOp enum. main() appends the generated traits/wrappers after
# this text and then closes the include guard.
START = """
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This is a file auto-generated by rvv_cpp_util_header_generator.py

#ifndef CORALNPU_TEST_UTILS_RVV_CPP_UTIL_H_
#define CORALNPU_TEST_UTILS_RVV_CPP_UTIL_H_

#include <riscv_vector.h>
#include <stddef.h>
#include <stdint.h>
#include <type_traits>

enum Lmul {
  INVALID,
  MF4,
  MF2,
  M1,
  M2,
  M4,
  M8,
};

// Returns the LMUL with twice the register-group size of `lmul`, or
// Lmul::INVALID when there is no wider setting (M8) or `lmul` is INVALID.
// constexpr (and therefore implicitly inline) so this header can be included
// from several translation units without multiple-definition link errors,
// and so the result is usable as a template argument, e.g.
// RvvType<WidenType<T>, Widen(lmul)>.
constexpr Lmul Widen(Lmul lmul) {
  switch(lmul) {
    case Lmul::MF4: return Lmul::MF2;
    case Lmul::MF2: return Lmul::M1;
    case Lmul::M1:  return Lmul::M2;
    case Lmul::M2:  return Lmul::M4;
    case Lmul::M4:  return Lmul::M8;
    default:
      return Lmul::INVALID;
  }
  return Lmul::INVALID;
}

template <typename T>
struct ScalarWidenTraits {};

template<> struct ScalarWidenTraits<uint8_t> { using Type = uint16_t; };
template<> struct ScalarWidenTraits<uint16_t> { using Type = uint32_t; };
template<> struct ScalarWidenTraits<int8_t> { using Type = int16_t; };
template<> struct ScalarWidenTraits<int16_t> { using Type = int32_t; };
template<typename T> using WidenType = typename ScalarWidenTraits<T>::Type;

enum SameTypeBinaryOp {
  VADD,
  VSADD,
  VSUB,
  VSSUB,
  VRSUB,
  VMUL,
  VMULH,
  VMIN,
  VMAX,
  VAND,
  VOR,
  VXOR,
};
"""

# Element widths (bits), signedness choices and LMUL settings covered by the
# generated header. LMULS is ordered from smallest to largest register group.
BITCOUNTS = [8, 16, 32]
SIGNED = [False, True]
LMULS = ["MF4", "MF2", "M1", "M2", "M4", "M8"]

# Largest element width (SEW, in bits) generated for each LMUL. RVV only
# provides a SEW/LMUL combination when SEW <= LMUL * ELEN; these limits are
# that bound for ELEN = 32 (presumably the target's ELEN — confirm), which is
# why e.g. 16-bit elements at MF4 and 32-bit elements at MF2 are excluded.
_MAX_SEW_FOR_LMUL = {
    "MF4": 8,
    "MF2": 16,
    "M1": 32,
    "M2": 32,
    "M4": 32,
    "M8": 32,
}


def all_bitcount_lmuls():
    """Return every (bit_count, lmul) pair the header generates code for.

    The pairs are taken from BITCOUNTS x LMULS (bit count major, LMUL minor,
    each in list order), keeping only combinations permitted by
    ``_MAX_SEW_FOR_LMUL``.

    Returns:
        A new list of ``(int, str)`` tuples such as ``(8, "MF4")``.
    """
    return [
        (bit_count, lmul)
        for bit_count, lmul in itertools.product(BITCOUNTS, LMULS)
        if bit_count <= _MAX_SEW_FOR_LMUL[lmul]
    ]


def all_signed_bitcounts_lmuls():
    """Return an iterator over ``(signed, (bit_count, lmul))`` tuples.

    Every unsigned combination comes first, followed by every signed one.
    The result is a single-use ``itertools.product`` iterator, so call this
    function again for each pass.
    """
    return itertools.product(SIGNED, all_bitcount_lmuls())

# Binary ops whose two operands and result all share one element type, in the
# same order as the SameTypeBinaryOp enum emitted in START. main() turns each
# name into a C++ wrapper (camel_case, e.g. "VADD" -> Vadd) that dispatches to
# the trait members "<name>_vv" / "<name>_vx"; VRSUB gets only the _vx form.
SAME_TYPE_BINARY_OPS = [
    # Add / subtract (plain and saturating), reverse subtract.
    "VADD", "VSADD", "VSUB", "VSSUB", "VRSUB",
    # Multiply (low and high half).
    "VMUL", "VMULH",
    # Min / max.
    "VMIN", "VMAX",
    # Bitwise.
    "VAND", "VOR", "VXOR",
]

def same_type_binary_op_trait(bit_count, signed, lmul):
    """Render the SameTypeBinaryOpTraits specialization for one element type.

    Args:
        bit_count: Element width in bits, e.g. 8, 16 or 32.
        signed: True for ``intN_t`` elements, False for ``uintN_t``.
        lmul: LMUL enumerator name such as "MF2" or "M4".

    Returns:
        C++ source text (starting with a newline) binding each ``<op>_vv`` /
        ``<op>_vx`` member to its ``__riscv_*`` intrinsic. For unsigned types
        the saturating add/sub, mulh, min and max intrinsics take a "u" suffix.
    """
    u = "" if signed else "u"
    suffix = ("i" if signed else "u") + f"{bit_count}{lmul.lower()}"
    both = ("vv", "vx")
    # (trait member stem, intrinsic stem, forms to emit)
    table = (
        ("vadd", "vadd", both),
        ("vsadd", "vsadd" + u, both),
        ("vsub", "vsub", both),
        ("vssub", "vssub" + u, both),
        ("vrsub", "vrsub", ("vx",)),  # reverse subtract: scalar form only
        ("vmul", "vmul", both),
        ("vmulh", "vmulh" + u, both),
        ("vmin", "vmin" + u, both),
        ("vmax", "vmax" + u, both),
        ("vand", "vand", both),
        ("vor", "vor", both),
        ("vxor", "vxor", both),
    )
    lines = [
        f"template<> struct SameTypeBinaryOpTraits<{u}int{bit_count}_t, Lmul::{lmul}> {{"
    ]
    for member, intrinsic, forms in table:
        lines.extend(
            f"  static constexpr auto {member}_{form} = __riscv_{intrinsic}_{form}_{suffix};"
            for form in forms
        )
    lines.append("};")
    return "\n" + "\n".join(lines) + "\n"

# Binary ops whose second operand has the same width as the first but is
# always unsigned (std::make_unsigned_t<T> in the generated wrappers). The
# trait specializations provide VSLL for every type, VSRA and VMULHSU only for
# signed types, and VSRL only for unsigned types.
MIXED_SIGN_SAME_WIDTH_TYPE_BINARY_OPS = [
    "VSLL", "VSRA", "VSRL",  # shifts
    "VMULHSU",               # signed x unsigned high multiply
]

def mixed_sign_same_width_type_binary_op_trait(bit_count, signed, lmul):
    """Render the MixedSignSameWidthTypeBinaryOpTraits specialization.

    Args:
        bit_count: Element width in bits, e.g. 8, 16 or 32.
        signed: True for ``intN_t`` elements, False for ``uintN_t``.
        lmul: LMUL enumerator name such as "MF2" or "M4".

    Returns:
        C++ source text (starting with a newline) with ``_vv`` and ``_vx``
        members for vsll plus, for signed types, vsra and vmulhsu, or, for
        unsigned types, vsrl.
    """
    suffix = ("i" if signed else "u") + f"{bit_count}{lmul.lower()}"
    c_type = ("" if signed else "u") + f"int{bit_count}_t"
    ops = ("vsll", "vsra", "vmulhsu") if signed else ("vsll", "vsrl")
    lines = [
        f"template<> struct MixedSignSameWidthTypeBinaryOpTraits<{c_type}, Lmul::{lmul}> {{"
    ]
    lines += [
        f"  static constexpr auto {op}_{form} = __riscv_{op}_{form}_{suffix};"
        for op in ops
        for form in ("vv", "vx")
    ]
    lines.append("};")
    return "\n" + "\n".join(lines) + "\n"


def camel_case(x):
    """Return ``x`` with its first character kept and the rest lower-cased.

    Turns op names such as "VMULHSU" into C++ function names such as
    "Vmulhsu". An empty string is returned unchanged rather than raising
    IndexError.
    """
    # Slicing (x[:1]) instead of indexing (x[0]) keeps "" from raising.
    return x[:1] + x[1:].lower()

def _c_int_name(signed, bit_count):
    """Return the C integer type name without its "_t" suffix, e.g. "uint16"."""
    return ("" if signed else "u") + f"int{bit_count}"


def _rvv_suffix(signed, bit_count, lmul):
    """Return the intrinsic type suffix for an element type and LMUL, e.g. "i8mf2"."""
    return ("i" if signed else "u") + f"{bit_count}{lmul.lower()}"


def _type_section():
    """Emit the RvvTypeTraits specializations followed by the RvvType alias."""
    section = ["template<typename T, Lmul lmul> struct RvvTypeTraits {};"]
    for signed, (bit_count, lmul) in all_signed_bitcounts_lmuls():
        c_int = _c_int_name(signed, bit_count)
        section.append(
            f"template<> struct RvvTypeTraits<{c_int}_t, Lmul::{lmul}> "
            f"{{ using type = v{c_int}{lmul.lower()}_t; }};"
        )
    section.append(
        "template<typename T, Lmul lmul>\n"
        "using RvvType = typename RvvTypeTraits<T, lmul>::type;\n"
    )
    return section


def _unit_stride_section(traits, mnemonic, wrapper):
    """Emit `traits` specializations binding `fn` to __riscv_<mnemonic><SEW>_v_*.

    `wrapper` (the generic C++ entry point that dispatches through `traits`)
    is appended verbatim after the specializations.
    """
    section = [f"template<typename T, Lmul lmul> struct {traits} {{}};"]
    for signed, (bit_count, lmul) in all_signed_bitcounts_lmuls():
        c_int = _c_int_name(signed, bit_count)
        fn = f"__riscv_{mnemonic}{bit_count}_v_{_rvv_suffix(signed, bit_count, lmul)}"
        section.append(
            f"template<> struct {traits}<{c_int}_t, Lmul::{lmul}> "
            f"{{ static constexpr auto fn = {fn}; }};"
        )
    section.append(wrapper)
    return section


def _op_wrapper(traits, op, form, second_param_type):
    """Emit a C++ wrapper named camel_case(op) forwarding to traits::<op>_<form>."""
    second = "vs2" if form == "vv" else "xs2"
    return f"""
template<typename T, Lmul lmul> RvvType<T, lmul>
{camel_case(op)}(RvvType<T, lmul> vs1, {second_param_type} {second}, size_t vl) {{
  return {traits}<T, lmul>::{op.lower()}_{form}(vs1, {second}, vl);
}}"""


def main():
    """Print the complete generated C++ header to stdout."""
    header = [START]

    # Vector types, then unit-stride loads and stores.
    header.extend(_type_section())
    header.extend(_unit_stride_section(
        "VleTraits", "vle",
        "template<typename T, Lmul lmul> RvvType<T, lmul> Vle(const T* ptr, size_t vl){ return VleTraits<T, lmul>::fn(ptr, vl); }\n"))
    header.extend(_unit_stride_section(
        "VseTraits", "vse",
        "template<typename T, Lmul lmul> void Vse(T* ptr, RvvType<T, lmul> v, size_t vl) { VseTraits<T, lmul>::fn(ptr, v, vl); }\n"))

    # Binary ops whose operands share one sign and width.
    same = "SameTypeBinaryOpTraits"
    header.append(f"template<typename T, Lmul lmul> struct {same} {{}};")
    header.extend(
        same_type_binary_op_trait(bit_count, signed, lmul)
        for signed, (bit_count, lmul) in all_signed_bitcounts_lmuls()
    )
    for op in SAME_TYPE_BINARY_OPS:
        if op != "VRSUB":  # the traits only define vrsub_vx
            header.append(_op_wrapper(same, op, "vv", "RvvType<T, lmul>"))
        header.append(_op_wrapper(same, op, "vx", "T"))

    # Binary ops whose second operand is the unsigned type of the same width.
    mixed = "MixedSignSameWidthTypeBinaryOpTraits"
    header.append(f"template<typename T, Lmul lmul> struct {mixed} {{}};")
    header.extend(
        mixed_sign_same_width_type_binary_op_trait(bit_count, signed, lmul)
        for signed, (bit_count, lmul) in all_signed_bitcounts_lmuls()
    )
    for op in MIXED_SIGN_SAME_WIDTH_TYPE_BINARY_OPS:
        header.append(_op_wrapper(mixed, op, "vv", "RvvType<std::make_unsigned_t<T>, lmul>"))
        header.append(_op_wrapper(mixed, op, "vx", "std::make_unsigned_t<T>"))

    header.append("#endif  // CORALNPU_TEST_UTILS_RVV_CPP_UTIL_H_")
    print("\n".join(header))


if __name__ == "__main__":
    main()